"""
Phase 12: Adjustment Channels Within Monetary Unions
=====================================================
When the exchange rate is removed, what adjusts? This phase decomposes
within-union adjustment into identifiable channels:

1. Current Account (CA) Component Decomposition — trade vs primary income vs secondary income (transfers) within EMU/CFA
2. Savings vs Investment — which side of S-I adjusts to demographic pressure?
3. Bilateral Capital Flows — do within-EMU demographic differences predict
   reallocation of bilateral holdings (CPIS portfolio debt and equity, plus FDI)?
4. Fiscal Channel — does fiscal balance mediate the demographic-CA effect?
5. Competitiveness Proxies — within-union deviations of growth, inflation and productivity growth as price/wage adjustment proxies
6. Gross Position Adjustment — debt vs equity vs FDI external positions

This is the DD extension from the research agenda: "Monetary Union Adjustment
Without the Exchange Rate."
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
# NOTE(review): this silences *every* warning for the whole process (e.g.
# pandas deprecation/copy warnings and numerical warnings from the GLS fit),
# which can hide real problems — consider narrowing to specific categories.
warnings.filterwarnings('ignore')

# PROJECT_DIR is the parent of the directory containing this file; ROOT_DIR is
# one level higher and also holds the sibling "multilateral" and
# "gravity_bilateral" projects referenced below.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Put multilateral/src first on the import path so that `model` (and its
# PanelGLS estimator) resolves to that project's module.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel (trilemma_panel.csv) location.
DATA = PROJECT_DIR / "data" / "processed"
# Every section writes its markdown table here; created at import time.
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# MULTILATERAL_DATA is not referenced in the code shown here; BILATERAL_DATA
# holds bilateral_panel.csv, read in section 3.
MULTILATERAL_DATA = ROOT_DIR / "multilateral" / "followup" / "data" / "processed"
BILATERAL_DATA = ROOT_DIR / "gravity_bilateral" / "data" / "processed"

# ── Monetary Union Definitions (from phase 10) ──────────────────────────

# Euro-area member (ISO3) -> first year treated as a member. filter_union
# keeps a country's rows only for years >= this value.
# NOTE(review): Croatia (HRV, euro since 2023) is not listed even though WDI
# data are requested through 2024 — confirm this intentionally mirrors phase 10.
EUROZONE_JOIN = {
    'AUT': 1999, 'BEL': 1999, 'FIN': 1999, 'FRA': 1999, 'DEU': 1999,
    'IRL': 1999, 'ITA': 1999, 'LUX': 1999, 'NLD': 1999, 'PRT': 1999,
    'ESP': 1999, 'GRC': 2001, 'SVN': 2007, 'CYP': 2008, 'MLT': 2008,
    'SVK': 2009, 'EST': 2011, 'LVA': 2014, 'LTU': 2015,
}
# Date-free set of euro members; used e.g. to drop euro countries from the
# OECD comparison group in every year.
EUROZONE_ISO3 = set(EUROZONE_JOIN.keys())

# CFA franc zone member (ISO3) -> first year treated as a member (same
# convention as EUROZONE_JOIN, consumed by filter_union).
CFA_JOIN = {
    'BEN': 1960, 'BFA': 1960, 'CIV': 1960, 'MLI': 1984,
    'NER': 1960, 'SEN': 1960, 'TGO': 1960, 'GNB': 1997,
    'CMR': 1960, 'CAF': 1960, 'TCD': 1960, 'COG': 1960,
    'GNQ': 1985, 'GAB': 1960,
}
# Date-free set of CFA members (not referenced elsewhere in the code shown).
CFA_ISO3 = set(CFA_JOIN.keys())

# OECD member countries (38 codes). Members outside EUROZONE_ISO3 form the
# "OECD floaters" comparison sample.
# NOTE(review): no exchange-rate-regime filter is applied (e.g. DNK pegs to
# the euro) — confirm "floaters" is meant loosely.
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
}


def stars(p):
    """Return the significance marker for p-value ``p``.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise the empty
    string (including for NaN, since every comparison with NaN is False).
    """
    # Thresholds are checked from the strictest upward; first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format a coefficient and its standard error for a results table.

    Returns a pair of strings: the coefficient to three decimals followed by
    its significance stars (from ``stars(p)``), and the standard error to
    three decimals wrapped in parentheses.
    """
    coefficient_text = format(val, '.3f') + stars(p)
    se_text = '(' + format(se, '.3f') + ')'
    return coefficient_text, se_text


def run_gls(df, y_var, x_vars, label):
    """Run PanelGLS on a country panel and return a results dict.

    Rows with a missing value in ``y_var``, any regressor, ``iso3`` or
    ``year`` are dropped first. The model is skipped (a message is printed
    and ``None`` returned) when a required column is absent from ``df``,
    fewer than 30 complete rows remain, fewer than 3 countries remain, or
    ``PanelGLS.fit`` raises.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``y_var``, ``x_vars``, ``iso3`` and ``year``.
    y_var : str
        Dependent variable column.
    x_vars : sequence of str
        Regressor columns, in coefficient order.
    label : str
        Model name used in messages and stored under ``'model'``.

    Returns
    -------
    dict or None
        ``'model'``, ``'dep_var'``, ``'n_obs'``, ``'n_countries'``,
        ``'r_squared'`` plus ``'<var>_coef'``, ``'<var>_se'``, ``'<var>_p'``
        for every regressor.
    """
    x_vars = list(x_vars)
    cols = [y_var] + x_vars + ['iso3', 'year']
    # Callers already skip absent DVs; absent controls must not abort the run.
    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"  {label}: missing columns {missing}, skipping")
        return None
    sub = df[cols].dropna()
    if len(sub) < 30:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None
    n_countries = sub['iso3'].nunique()
    if n_countries < 3:
        print(f"  {label}: insufficient countries ({n_countries}), skipping")
        return None
    gls = PanelGLS()
    y = sub[y_var].values
    X = sub[x_vars].values
    try:
        gls.fit(y, X, sub['iso3'].values, sub['year'].values)
    except Exception as e:  # best effort: a failed fit skips this model only
        print(f"  {label}: GLS failed ({e}), skipping")
        return None
    result = {
        'model': label, 'dep_var': y_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
    }
    for i, var in enumerate(x_vars):
        result[f'{var}_coef'] = gls.beta[i]
        result[f'{var}_se'] = gls.se[i]
        result[f'{var}_p'] = gls.pvalues[i]
    return result


def run_bilateral_gls(df, y_var, x_vars, label):
    """Run PanelGLS for bilateral data using ``pair_id`` as the entity.

    Rows with a missing value in ``y_var``, any regressor, ``pair_id`` or
    ``year`` are dropped first. The model is skipped (a message is printed
    and ``None`` returned) when a required column is absent from ``df``,
    fewer than 50 complete rows remain, or ``PanelGLS.fit`` raises.

    Returns
    -------
    dict or None
        ``'model'``, ``'dep_var'``, ``'n_obs'``, ``'n_pairs'`` (the estimator's
        entity count), ``'r_squared'`` plus ``'<var>_coef'``, ``'<var>_se'``,
        ``'<var>_p'`` for every regressor.
    """
    x_vars = list(x_vars)
    cols = [y_var] + x_vars + ['pair_id', 'year']
    # Skip rather than crash when a control is absent from this panel.
    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"  {label}: missing columns {missing}, skipping")
        return None
    sub = df[cols].dropna()
    if len(sub) < 50:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None
    gls = PanelGLS()
    y = sub[y_var].values
    X = sub[x_vars].values
    try:
        gls.fit(y, X, sub['pair_id'].values, sub['year'].values)
    except Exception as e:  # best effort: a failed fit skips this model only
        print(f"  {label}: GLS failed ({e}), skipping")
        return None
    result = {
        'model': label, 'dep_var': y_var,
        'n_obs': gls.n_obs, 'n_pairs': gls.n_countries,
        'r_squared': gls.r_squared,
    }
    for i, var in enumerate(x_vars):
        result[f'{var}_coef'] = gls.beta[i]
        result[f'{var}_se'] = gls.se[i]
        result[f'{var}_p'] = gls.pvalues[i]
    return result


def filter_union(df, members_dict, col='iso3'):
    """Filter a panel to union members during their membership years.

    A row is kept when ``df[col]`` is a key of ``members_dict`` and
    ``df['year']`` is at least that member's join year. Row order and index
    are preserved, and the result is always an independent copy (also when
    ``members_dict`` is empty, in which case it has no rows).

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with columns ``col`` and ``year``.
    members_dict : dict
        Member code -> first membership year.
    col : str
        Column holding the member codes.
    """
    # Non-members map to NaN, and any comparison with NaN is False, so one
    # vectorized comparison replaces a per-member mask.
    join_year = df[col].map(members_dict)
    return df[df['year'] >= join_year].copy()


def add_z_deviations(df, z_vars=('Z_1', 'Z_2', 'Z_3')):
    """Add within-sample Z deviations (cross-sectional demeaning per year).

    For every name in ``z_vars`` that is a column of ``df``, adds
    ``<name>_dev`` = value minus the mean of that column over all rows with
    the same ``year`` (NaNs are ignored in the mean and stay NaN). Columns
    not present are skipped.

    ``df`` is modified in place and also returned, so callers can chain.
    """
    for z in z_vars:
        if z in df.columns:
            df[f'{z}_dev'] = df[z] - df.groupby('year')[z].transform('mean')
    return df


def write_table(rows, headers, filepath, title, notes=None):
    """Write a markdown table (UTF-8) to ``filepath``.

    Layout: ``# title``, a blank line, the header row, a ``---`` separator
    row, one line per entry of ``rows`` (cells converted with ``str``), and,
    if ``notes`` is truthy, a blank line followed by ``notes``.

    Parameters
    ----------
    rows : iterable of sequences
        Table body.
    headers : sequence of str
        Column headers.
    filepath : pathlib.Path
        Destination; its ``.name`` is printed after writing.
    title : str
        Level-1 heading text.
    notes : str, optional
        Free text appended after the table.
    """
    lines = [
        f"# {title}",
        "",
        '| ' + ' | '.join(headers) + ' |',
        '|' + '|'.join(['---'] * len(headers)) + '|',
    ]
    lines.extend('| ' + ' | '.join(str(c) for c in row) + ' |' for row in rows)
    text = '\n'.join(lines) + '\n'
    if notes:
        text += f"\n{notes}\n"
    # Explicit UTF-8: titles, cells and notes contain '→', '—' and '²', which
    # the platform default encoding (e.g. cp1252) may not be able to encode.
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(text)
    print(f"  Wrote: {filepath.name}")


# ── Load Data ────────────────────────────────────────────────────────────

def load_data():
    """Load the trilemma panel and attach CA components.

    Reads ``trilemma_panel.csv`` from DATA, asks ``download_ca_components``
    for component columns keyed on (iso3, year), and left-merges them onto
    the panel when something is returned.
    """
    print("Loading trilemma panel...")
    panel = pd.read_csv(DATA / "trilemma_panel.csv")

    print("Downloading CA components from World Bank WDI...")
    components = download_ca_components(panel)
    if components is None:
        return panel

    panel = panel.merge(components, on=['iso3', 'year'], how='left')
    print(f"  Merged CA components: {components.columns.tolist()}")
    return panel


def _fetch_wdi_series(code, name):
    """Fetch one WDI indicator (all countries, 1995-2024) as an (iso3, year, name) frame.

    Returns ``None`` when the response has no record list or no non-null
    values. Network, HTTP and JSON errors propagate to the caller.
    """
    import json
    import urllib.request

    json_url = f"https://api.worldbank.org/v2/country/all/indicator/{code}?format=json&per_page=20000&date=1995:2024"
    req = urllib.request.Request(json_url)
    req.add_header('User-Agent', 'Mozilla/5.0')
    with urllib.request.urlopen(req, timeout=30) as resp:
        data = json.loads(resp.read().decode())

    # Expected shape: [metadata, records]; anything else means "no data".
    if len(data) < 2 or data[1] is None:
        return None

    # Only page 1 is requested; make truncation visible instead of silent.
    meta = data[0] if isinstance(data[0], dict) else {}
    pages = meta.get('pages', 1)
    if isinstance(pages, int) and pages > 1:
        print(f"  WDI {code}: response has {pages} pages, only page 1 was read")

    records = [
        {
            'iso3': item['countryiso3code'],
            'year': int(item['date']),
            name: float(item['value']),
        }
        for item in data[1]
        if item['value'] is not None
    ]
    # An empty frame would have no iso3/year columns and break later merges.
    if not records:
        return None
    return pd.DataFrame(records)


def download_ca_components(tri):
    """Download trade balance, income balance, secondary income from WDI.

    Each indicator is fetched on a best-effort basis: a failure or empty
    response is reported and that indicator skipped. The series are
    outer-merged on (iso3, year), inner-joined to ``tri``'s ``ngdp_usd``,
    and converted to percent of GDP as ``trade_balance_gdp``,
    ``primary_income_gdp`` and ``secondary_income_gdp``. All ``*_usd``
    columns are then dropped. If no indicator could be downloaded, returns
    ``create_ca_components_from_existing(tri)`` instead.
    """
    indicators = {
        'BN.GSR.GNFS.CD': 'trade_balance_usd',      # Net trade in goods & services
        'BN.GSR.FCTY.CD': 'primary_income_usd',      # Net primary income
        'BN.TRF.CURR.CD': 'secondary_income_usd',    # Net secondary income (transfers)
    }

    all_dfs = []
    for code, name in indicators.items():
        try:
            df = _fetch_wdi_series(code, name)
        except Exception as e:  # best effort: any fetch/parse error skips this indicator
            print(f"  WDI {code} failed: {e}")
            continue
        if df is None:
            print(f"  WDI {code}: no data returned")
            continue
        all_dfs.append(df)
        print(f"  WDI {code} ({name}): {len(df)} obs")

    if not all_dfs:
        print("  No WDI data downloaded, using fallback approach")
        return create_ca_components_from_existing(tri)

    # Merge all indicators
    result = all_dfs[0]
    for df in all_dfs[1:]:
        result = result.merge(df, on=['iso3', 'year'], how='outer')

    # Convert to % of GDP using ngdp_usd from trilemma panel
    gdp = tri[['iso3', 'year', 'ngdp_usd']].dropna()
    result = result.merge(gdp, on=['iso3', 'year'], how='inner')

    for usd_col in indicators.values():
        if usd_col in result.columns:
            # WDI values are in current USD; the 1e9 factor assumes ngdp_usd is
            # in billions (per the original author) — NOTE(review): confirm units.
            gdp_col = usd_col.replace('_usd', '_gdp')
            result[gdp_col] = (result[usd_col] / (result['ngdp_usd'] * 1e9)) * 100

    # Drops ngdp_usd together with the raw USD balances.
    return result.drop(columns=[c for c in result.columns if c.endswith('_usd')])


def create_ca_components_from_existing(tri):
    """Build a fallback CA split from columns already in the panel.

    Keeps rows where CA, gross savings and gross investment are all present,
    and returns (iso3, year, si_gap, non_si_residual). Here si_gap is savings
    minus investment and non_si_residual is the CA balance left unexplained
    by that gap.
    """
    needed = ['iso3', 'year', 'ca_gdp', 'gross_savings_gdp', 'gross_investment_gdp']
    complete = tri.loc[:, needed].dropna()
    gap = complete['gross_savings_gdp'] - complete['gross_investment_gdp']
    # assign() returns a new frame, so no in-place writes on a filtered slice.
    complete = complete.assign(si_gap=gap, non_si_residual=complete['ca_gdp'] - gap)
    return complete[['iso3', 'year', 'si_gap', 'non_si_residual']]


# ══════════════════════════════════════════════════════════════════════════
# SECTION 1: CA Component Decomposition Within Monetary Unions
# ══════════════════════════════════════════════════════════════════════════

def section1_ca_decomposition(tri):
    """Test which CA components carry the demographic signal within EMU and CFA.

    For each union (members restricted to membership years by filter_union)
    and for non-euro OECD countries as a comparison group, every available
    CA component is regressed on year-demeaned Z deviations plus controls
    with ``run_gls``. The Z_1_dev coefficients are written to
    ``phase12_ca_decomposition.md``. Unions with fewer than 50 rows are
    skipped.

    The savings-investment split (``si_gap`` / ``non_si_residual``) is
    always included. When it has to be derived here, it is added to a local
    copy of ``tri``, so the panel shared with the other sections is not
    modified. If its inputs are missing, those DVs are simply skipped.

    Returns
    -------
    list of dict
        ``run_gls`` results, each tagged with ``'union'``.
    """
    print("\n" + "=" * 70)
    print("SECTION 1: CA Component Decomposition Within Monetary Unions")
    print("=" * 70)

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
    z_vars = ['Z_1', 'Z_2', 'Z_3']

    # CA components present in the panel (the WDI download adds the last three).
    ca_dvs = [col for col in ['ca_gdp', 'trade_balance_gdp', 'primary_income_gdp', 'secondary_income_gdp']
              if col in tri.columns]

    # Always include the savings-investment decomposition. Derive it on a copy
    # (not the caller's frame), and only when its inputs are available.
    si_inputs = ['gross_savings_gdp', 'gross_investment_gdp', 'ca_gdp']
    if 'si_gap' not in tri.columns and all(c in tri.columns for c in si_inputs):
        tri = tri.copy()
        tri['si_gap'] = tri['gross_savings_gdp'] - tri['gross_investment_gdp']
        tri['non_si_residual'] = tri['ca_gdp'] - tri['si_gap']
    # DVs absent from a sample are skipped by the loops below.
    ca_dvs.extend(['si_gap', 'non_si_residual'])

    results_all = []

    for union_name, members_dict in [('EMU', EUROZONE_JOIN), ('CFA', CFA_JOIN)]:
        print(f"\n--- {union_name} ---")
        union_df = filter_union(tri, members_dict)
        union_df = add_z_deviations(union_df)

        if len(union_df) < 50:
            print(f"  {union_name}: only {len(union_df)} obs, skipping")
            continue

        # Use Z_dev within union
        z_dev_vars = [f'{z}_dev' for z in z_vars if f'{z}_dev' in union_df.columns]
        xvars_dev = z_dev_vars + controls

        for dv in ca_dvs:
            if dv not in union_df.columns:
                continue
            label = f"{union_name}: Z_dev → {dv}"
            r = run_gls(union_df, dv, xvars_dev, label)
            if r:
                r['union'] = union_name
                results_all.append(r)

    # Also run OECD floaters as comparison
    print("\n--- OECD Floaters (comparison) ---")
    floaters = tri[(tri['iso3'].isin(OECD)) & (~tri['iso3'].isin(EUROZONE_ISO3))].copy()
    floaters = add_z_deviations(floaters)
    z_dev_vars = [f'{z}_dev' for z in z_vars if f'{z}_dev' in floaters.columns]
    xvars_dev = z_dev_vars + controls
    for dv in ca_dvs:
        if dv not in floaters.columns:
            continue
        label = f"OECD Floaters: Z_dev → {dv}"
        r = run_gls(floaters, dv, xvars_dev, label)
        if r:
            r['union'] = 'OECD_Float'
            results_all.append(r)

    # Write table
    if results_all:
        headers = ['Sample', 'Dep. Var', 'Z_1_dev', '', 'N', 'Countries', 'R²']
        rows = []
        for r in results_all:
            if 'Z_1_dev_coef' in r:
                coef_str, se_str = fmt(r['Z_1_dev_coef'], r['Z_1_dev_se'], r['Z_1_dev_p'])
            else:
                coef_str, se_str = '—', ''
            rows.append([
                r['union'], r['dep_var'],
                coef_str, se_str,
                r['n_obs'], r['n_countries'], f"{r['r_squared']:.3f}"
            ])
        write_table(rows, headers, OUT_TABLES / "phase12_ca_decomposition.md",
                    "Within-Union CA Component Decomposition",
                    "*Notes: PanelGLS with country and year FE. Z_dev = within-union cross-sectional demeaning. Controls: fiscal balance, lagged NFA, growth, relative productivity, KAOPEN.*\n*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    return results_all


# ══════════════════════════════════════════════════════════════════════════
# SECTION 2: Savings vs Investment Adjustment
# ══════════════════════════════════════════════════════════════════════════

def section2_savings_investment(tri):
    """Estimate which side of S-I responds to within-union demographics.

    Gross savings, gross investment and the savings-investment gap (each if
    present) are regressed on year-demeaned Z deviations plus controls. This
    is done for EMU and CFA members during membership years, then for
    non-euro OECD countries as a comparison. Z_1_dev coefficients are written
    to ``phase12_si_adjustment.md``.

    Returns
    -------
    list of dict
        ``run_gls`` results tagged with ``'union'`` ('EMU', 'CFA' or
        'OECD_Float').
    """
    rule = "=" * 70
    print("\n" + rule)
    print("SECTION 2: Savings vs Investment Adjustment Within Unions")
    print(rule)

    ctrl = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
    z_names = ['Z_1', 'Z_2', 'Z_3']
    outcomes = ['gross_savings_gdp', 'gross_investment_gdp', 'savings_investment_gap']

    estimates = []

    def fit_outcomes(sample, tag, prefix):
        # Regress each outcome present in `sample` on its Z_dev columns + controls.
        rhs = [f'{z}_dev' for z in z_names if f'{z}_dev' in sample.columns] + ctrl
        for outcome in outcomes:
            if outcome not in sample.columns:
                continue
            est = run_gls(sample, outcome, rhs, f"{prefix}: Z_dev → {outcome}")
            if est:
                est['union'] = tag
                estimates.append(est)

    for tag, membership in (('EMU', EUROZONE_JOIN), ('CFA', CFA_JOIN)):
        print(f"\n--- {tag} ---")
        fit_outcomes(add_z_deviations(filter_union(tri, membership)), tag, tag)

    # Non-euro OECD countries as the floating-rate comparison group.
    print("\n--- OECD Floaters ---")
    non_emu_oecd = tri['iso3'].isin(OECD) & ~tri['iso3'].isin(EUROZONE_ISO3)
    fit_outcomes(add_z_deviations(tri[non_emu_oecd].copy()), 'OECD_Float', 'OECD Float')

    if estimates:
        table = []
        for est in estimates:
            if 'Z_1_dev_coef' in est:
                coef_cell, se_cell = fmt(est['Z_1_dev_coef'], est['Z_1_dev_se'], est['Z_1_dev_p'])
            else:
                coef_cell, se_cell = '—', ''
            table.append([
                est['union'], est['dep_var'], coef_cell, se_cell,
                est['n_obs'], est['n_countries'], f"{est['r_squared']:.3f}",
            ])
        write_table(
            table,
            ['Sample', 'Dep. Var', 'Z_1_dev', '', 'N', 'Countries', 'R²'],
            OUT_TABLES / "phase12_si_adjustment.md",
            "Savings vs Investment Adjustment Within Unions",
            "*Notes: PanelGLS with country and year FE. Z_dev = within-union demeaning. Controls: fiscal balance, lagged NFA, growth, relative productivity, KAOPEN.*\n*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*",
        )

    return estimates


# ══════════════════════════════════════════════════════════════════════════
# SECTION 3: Bilateral Capital Flow Adjustment Within EMU
# ══════════════════════════════════════════════════════════════════════════

def section3_bilateral_flows(tri):
    """Do within-EMU demographic differences predict bilateral portfolio reallocation?

    Loads ``bilateral_panel.csv`` and estimates pair + year PanelGLS models of
    log bilateral portfolio/FDI positions on demographic distance:

    * Test 1 — signed dZ (reporter Z minus partner Z) for within-EMU pairs;
    * Test 2 — absolute |dZ| for within-EMU pairs;
    * Test 3 — signed dZ for pairs of non-euro OECD countries (comparison).

    Within-EMU pairs are restricted to years in which BOTH countries were
    already euro members, matching the membership-year convention that
    ``filter_union`` applies in the other sections. ``tri`` is unused; it is
    accepted so every section shares the same signature.

    Returns
    -------
    list of dict
        ``run_bilateral_gls`` results tagged with ``'test'``.
    """
    print("\n" + "=" * 70)
    print("SECTION 3: Bilateral Capital Flows Within EMU")
    print("=" * 70)

    # Load bilateral panel
    print("Loading bilateral panel...")
    bp = pd.read_csv(BILATERAL_DATA / "bilateral_panel.csv")

    # Within-EMU pairs during joint membership only. Non-members map to NaN
    # join years and comparisons with NaN are False, so they drop out too.
    reporter_join = bp['reporter'].map(EUROZONE_JOIN)
    partner_join = bp['partner'].map(EUROZONE_JOIN)
    both_members = (bp['year'] >= reporter_join) & (bp['year'] >= partner_join)
    emu_bp = bp[both_members].copy()
    print(f"  Within-EMU bilateral obs: {len(emu_bp)}")

    # Absolute demographic distance measures
    for dz in ['dZ_1', 'dZ_2', 'dZ_3']:
        emu_bp[f'{dz}_abs'] = emu_bp[dz].abs()

    # Create pair_id if not present (directed: reporter_partner)
    if 'pair_id' not in emu_bp.columns:
        emu_bp['pair_id'] = emu_bp['reporter'] + '_' + emu_bp['partner']

    # Dependent variables: log bilateral flows
    flow_dvs = {
        'log_portfolio_total': 'Portfolio Total',
        'log_portfolio_debt': 'Portfolio Debt',
        'log_portfolio_equity': 'Portfolio Equity',
        'log_fdi_outward': 'FDI',
    }

    # Bilateral controls, plus contiguity / common language if available
    bilateral_controls = ['log_dist', 'log_gdp_product']
    for c in ['contiguity', 'common_lang_official']:
        if c in emu_bp.columns:
            bilateral_controls.append(c)

    results_all = []

    # Test 1: Signed demographic distance → directed flows
    print("\n  Test 1: Signed dZ → bilateral flows (within EMU)")
    x_signed = ['dZ_1', 'dZ_2', 'dZ_3'] + bilateral_controls
    for dv, label in flow_dvs.items():
        if dv not in emu_bp.columns:
            continue
        r = run_bilateral_gls(emu_bp, dv, x_signed, f"EMU signed: dZ → {label}")
        if r:
            r['test'] = 'signed_dZ'
            results_all.append(r)

    # Test 2: Absolute demographic distance → bilateral flows
    print("\n  Test 2: |dZ| → bilateral flows (within EMU)")
    x_abs = ['dZ_1_abs', 'dZ_2_abs', 'dZ_3_abs'] + bilateral_controls
    for dv, label in flow_dvs.items():
        if dv not in emu_bp.columns:
            continue
        r = run_bilateral_gls(emu_bp, dv, x_abs, f"EMU absolute: |dZ| → {label}")
        if r:
            r['test'] = 'abs_dZ'
            results_all.append(r)

    # Test 3: Non-EMU OECD comparison (within OECD floaters)
    print("\n  Test 3: dZ → bilateral flows (OECD non-EMU pairs, comparison)")
    non_emu_oecd = OECD - EUROZONE_ISO3
    oecd_bp = bp[(bp['reporter'].isin(non_emu_oecd)) &
                 (bp['partner'].isin(non_emu_oecd))].copy()
    if 'pair_id' not in oecd_bp.columns:
        oecd_bp['pair_id'] = oecd_bp['reporter'] + '_' + oecd_bp['partner']
    print(f"  Non-EMU OECD bilateral obs: {len(oecd_bp)}")

    for dv, label in flow_dvs.items():
        if dv not in oecd_bp.columns:
            continue
        r = run_bilateral_gls(oecd_bp, dv, x_signed, f"OECD non-EMU: dZ → {label}")
        if r:
            r['test'] = 'oecd_comparison'
            results_all.append(r)

    # Write results table
    if results_all:
        headers = ['Test', 'Dep. Var', 'dZ_1', '', 'N', 'Pairs', 'R²']
        rows = []
        for r in results_all:
            # Report the signed dZ_1 coefficient, or |dZ_1| for Test 2.
            for key_prefix in ['dZ_1_coef', 'dZ_1_abs_coef']:
                if key_prefix in r:
                    se_key = key_prefix.replace('_coef', '_se')
                    p_key = key_prefix.replace('_coef', '_p')
                    coef_str, se_str = fmt(r[key_prefix], r[se_key], r[p_key])
                    break
            else:
                coef_str, se_str = '—', ''

            rows.append([
                r['test'], r['dep_var'],
                coef_str, se_str,
                r['n_obs'], r.get('n_pairs', '—'), f"{r['r_squared']:.3f}"
            ])

        write_table(rows, headers, OUT_TABLES / "phase12_bilateral_flows.md",
                    "Bilateral Capital Flow Adjustment Within EMU",
                    "*Notes: PanelGLS with pair and year FE. dZ = reporter Z minus partner Z. Controls: log distance, log GDP product, contiguity, common language.*\n*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    return results_all


# ══════════════════════════════════════════════════════════════════════════
# SECTION 4: Fiscal Channel — Does Fiscal Balance Mediate?
# ══════════════════════════════════════════════════════════════════════════

def section4_fiscal_mediation(tri):
    """Check whether the fiscal balance mediates the demographic CA effect.

    Within EMU and CFA (membership years), estimates in order:
    1. the first stage Z_dev → fiscal balance;
    2. Z_dev → CA without the fiscal balance;
    3. Z_dev → CA with the fiscal balance as a control;
    4. if ``savings_investment_gap`` exists, Z_dev → CA controlling for both
       fiscal balance and the S-I gap.
    Per the table notes, attenuation of Z_1_dev once the fiscal balance is
    added is read as fiscal mediation. Rows go to
    ``phase12_fiscal_mediation.md``.

    Returns
    -------
    list of dict
        ``run_gls`` results tagged with ``'union'`` and ``'step'``.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("SECTION 4: Fiscal Mediation Within Monetary Unions")
    print(rule)

    z_names = ['Z_1', 'Z_2', 'Z_3']
    base = ['nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']

    estimates = []

    for tag, membership in (('EMU', EUROZONE_JOIN), ('CFA', CFA_JOIN)):
        print(f"\n--- {tag} ---")
        sample = add_z_deviations(filter_union(tri, membership))
        zdev = [f'{z}_dev' for z in z_names if f'{z}_dev' in sample.columns]

        # (dependent variable, regressors, label, step tag), estimated in order.
        specs = [
            ('fiscal_bal_gdp', zdev + base,
             f"{tag}: Z_dev → fiscal_bal (1st stage)", 'first_stage'),
            ('ca_gdp', zdev + base,
             f"{tag}: Z_dev → CA (no fiscal)", 'no_fiscal'),
            ('ca_gdp', zdev + ['fiscal_bal_gdp'] + base,
             f"{tag}: Z_dev → CA (with fiscal)", 'with_fiscal'),
        ]
        if 'savings_investment_gap' in sample.columns:
            specs.append((
                'ca_gdp', zdev + ['fiscal_bal_gdp', 'savings_investment_gap'] + base,
                f"{tag}: Z_dev → CA (with fiscal + S-I)", 'with_fiscal_si',
            ))

        for outcome, rhs, label, step in specs:
            est = run_gls(sample, outcome, rhs, label)
            if est:
                est['union'] = tag
                est['step'] = step
                estimates.append(est)

    if estimates:
        table = []
        for est in estimates:
            if 'Z_1_dev_coef' in est:
                coef_cell, se_cell = fmt(est['Z_1_dev_coef'], est['Z_1_dev_se'], est['Z_1_dev_p'])
            else:
                coef_cell, se_cell = '—', ''
            table.append([
                est['union'], est['step'], est['dep_var'], coef_cell, se_cell,
                est['n_obs'], f"{est['r_squared']:.3f}",
            ])
        write_table(
            table,
            ['Union', 'Step', 'Dep. Var', 'Z_1_dev', '', 'N', 'R²'],
            OUT_TABLES / "phase12_fiscal_mediation.md",
            "Fiscal Mediation of Demographic CA Effect Within Unions",
            "*Notes: PanelGLS with country and year FE. Mediation test: if Z_1_dev attenuates when fiscal_bal is added, fiscal channel mediates. Controls: lagged NFA, growth, relative productivity, KAOPEN.*\n*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*",
        )

    return estimates


# ══════════════════════════════════════════════════════════════════════════
# SECTION 5: Competitiveness / Price Adjustment Proxies
# ══════════════════════════════════════════════════════════════════════════

def section5_competitiveness(tri):
    """Test whether growth, inflation and productivity growth adjust to demographics.

    For EMU and CFA (membership years), each proxy is expressed as a
    deviation from its within-union annual mean (``<var>_dev``) and regressed
    on Z_dev plus controls. ``productivity_growth`` is included only when
    ``output_per_worker`` exists. A growth-mediation check for EMU then
    compares Z_dev → CA_dev without and with ``growth_dev``.

    Rows are written to ``phase12_competitiveness.md``. A Step column
    separates the two mediation specifications, which share the same union
    and dependent variable. The caller's ``tri`` is not modified:
    ``sort_values`` returns a new frame before productivity growth is added.

    Returns
    -------
    list of dict
        ``run_gls`` results tagged with ``'union'`` (mediation rows also
        carry ``'step'``).
    """
    print("\n" + "=" * 70)
    print("SECTION 5: Competitiveness / Price Adjustment Within Unions")
    print("=" * 70)

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'kaopen']

    # Competitiveness DVs: growth, inflation, and (if available) productivity growth
    dvs = ['rgdp_growth', 'inflation']
    if 'output_per_worker' in tri.columns:
        tri = tri.sort_values(['iso3', 'year'])
        tri['productivity_growth'] = tri.groupby('iso3')['output_per_worker'].pct_change() * 100
        dvs.append('productivity_growth')

    results_all = []

    for union_name, members_dict in [('EMU', EUROZONE_JOIN), ('CFA', CFA_JOIN)]:
        print(f"\n--- {union_name} ---")
        union_df = filter_union(tri, members_dict)
        union_df = add_z_deviations(union_df)
        z_dev_vars = [f'{z}_dev' for z in z_vars if f'{z}_dev' in union_df.columns]

        # Within-union deviations of each proxy; these are the DVs used below.
        for v in ['rgdp_growth', 'inflation', 'productivity_growth']:
            if v in union_df.columns:
                union_df[f'{v}_dev'] = union_df.groupby('year')[v].transform(lambda x: x - x.mean())

        for dv in dvs:
            dv_dev = f'{dv}_dev'
            if dv_dev not in union_df.columns:
                continue
            label = f"{union_name}: Z_dev → {dv}_dev"
            r = run_gls(union_df, dv_dev, z_dev_vars + controls, label)
            if r:
                r['union'] = union_name
                results_all.append(r)

    # Also test: does demographic divergence predict CA adjustment via growth?
    # Mediation: Z_dev → growth_dev → CA_dev
    print("\n--- Growth mediation of CA within EMU ---")
    emu_df = filter_union(tri, EUROZONE_JOIN)
    emu_df = add_z_deviations(emu_df)
    emu_df['ca_dev'] = emu_df.groupby('year')['ca_gdp'].transform(lambda x: x - x.mean())
    emu_df['growth_dev'] = emu_df.groupby('year')['rgdp_growth'].transform(lambda x: x - x.mean())

    z_dev_vars = [f'{z}_dev' for z in z_vars if f'{z}_dev' in emu_df.columns]

    # CA without growth
    r = run_gls(emu_df, 'ca_dev', z_dev_vars + ['fiscal_bal_gdp', 'nfa_gdp_lag', 'kaopen'],
                "EMU: Z_dev → CA_dev (no growth)")
    if r:
        r['union'] = 'EMU'
        r['step'] = 'no_growth'
        results_all.append(r)

    # CA with growth
    r = run_gls(emu_df, 'ca_dev', z_dev_vars + ['growth_dev', 'fiscal_bal_gdp', 'nfa_gdp_lag', 'kaopen'],
                "EMU: Z_dev → CA_dev (with growth)")
    if r:
        r['union'] = 'EMU'
        r['step'] = 'with_growth'
        results_all.append(r)

    if results_all:
        # 'Step' distinguishes the two ca_dev mediation rows; other rows show '—'.
        headers = ['Union', 'Step', 'Dep. Var', 'Z_1_dev', '', 'N', 'R²']
        rows = []
        for r in results_all:
            if 'Z_1_dev_coef' in r:
                coef_str, se_str = fmt(r['Z_1_dev_coef'], r['Z_1_dev_se'], r['Z_1_dev_p'])
            else:
                coef_str, se_str = '—', ''
            rows.append([
                r['union'], r.get('step', '—'), r['dep_var'],
                coef_str, se_str,
                r['n_obs'], f"{r['r_squared']:.3f}"
            ])
        write_table(rows, headers, OUT_TABLES / "phase12_competitiveness.md",
                    "Competitiveness / Price Adjustment Within Unions",
                    "*Notes: PanelGLS with country and year FE. _dev variables are deviations from within-union annual means. Controls: fiscal balance, lagged NFA, KAOPEN.*\n*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    return results_all


# ══════════════════════════════════════════════════════════════════════════
# SECTION 6: Gross External Position Adjustment
# ══════════════════════════════════════════════════════════════════════════

def section6_gross_positions(tri):
    """Estimate how gross external positions respond to within-union demographics.

    EMU members (membership years only) and, for comparison, non-euro OECD
    countries are each demeaned on Z by year. Every available position/GDP
    ratio is regressed on the Z deviations plus the standard controls, and
    the Z_1_dev coefficients are written to ``phase12_gross_positions.md``.

    Returns
    -------
    list of dict
        ``run_gls`` results tagged with ``'union'`` ('EMU' or 'OECD_Float').
    """
    rule = "=" * 70
    print("\n" + rule)
    print("SECTION 6: Gross External Position Adjustment Within Unions")
    print(rule)

    z_names = ['Z_1', 'Z_2', 'Z_3']
    ctrl = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
    positions = [
        'gross_assets_gdp', 'gross_liab_gdp',
        'debt_assets_gdp', 'debt_liab_gdp',
        'port_eq_assets_gdp',
        'fdi_assets_gdp', 'fdi_liab_gdp',
        'fx_reserves_gdp',
    ]

    estimates = []

    def fit_positions(sample, tag, prefix):
        # Regress each position present in `sample` on its Z_dev columns + controls.
        rhs = [f'{z}_dev' for z in z_names if f'{z}_dev' in sample.columns] + ctrl
        for position in positions:
            if position not in sample.columns:
                continue
            est = run_gls(sample, position, rhs, f"{prefix}: Z_dev → {position}")
            if est:
                est['union'] = tag
                estimates.append(est)

    print("\n--- EMU ---")
    fit_positions(add_z_deviations(filter_union(tri, EUROZONE_JOIN)), 'EMU', 'EMU')

    # Non-euro OECD countries as the floating-rate comparison group.
    print("\n--- OECD Floaters ---")
    non_emu_oecd = tri['iso3'].isin(OECD) & ~tri['iso3'].isin(EUROZONE_ISO3)
    fit_positions(add_z_deviations(tri[non_emu_oecd].copy()), 'OECD_Float', 'OECD Float')

    if estimates:
        table = []
        for est in estimates:
            if 'Z_1_dev_coef' in est:
                coef_cell, se_cell = fmt(est['Z_1_dev_coef'], est['Z_1_dev_se'], est['Z_1_dev_p'])
            else:
                coef_cell, se_cell = '—', ''
            table.append([
                est['union'], est['dep_var'], coef_cell, se_cell,
                est['n_obs'], est['n_countries'], f"{est['r_squared']:.3f}",
            ])
        write_table(
            table,
            ['Sample', 'Dep. Var', 'Z_1_dev', '', 'N', 'Countries', 'R²'],
            OUT_TABLES / "phase12_gross_positions.md",
            "Gross External Position Adjustment Within Unions",
            "*Notes: PanelGLS with country and year FE. Z_dev = within-union demeaning (OECD floaters: within-group demeaning). Controls: fiscal balance, lagged NFA, growth, relative productivity, KAOPEN.*\n*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*",
        )

    return estimates


# ══════════════════════════════════════════════════════════════════════════
# SECTION 7: Summary — Which Channels Adjust?
# ══════════════════════════════════════════════════════════════════════════

def section7_summary(all_results):
    """Create a summary table of which adjustment channels are active.

    Parameters
    ----------
    all_results : dict
        Section name -> list of result dicts returned by that section. From
        each dict, the first coefficient present among ``Z_1_dev``, ``dZ_1``
        and ``dZ_1_abs`` is reported; dicts with none of them are skipped.

    The "Sample" column is the result's ``'union'`` tag (falling back to
    ``'test'``). When the result carries a ``'step'`` tag, it is appended in
    parentheses, so mediation specifications that share a union and
    dependent variable remain distinguishable. Writes
    ``phase12_summary.md``; returns nothing.
    """
    print("\n" + "=" * 70)
    print("SECTION 7: Summary of Adjustment Channels")
    print("=" * 70)

    # Preference order for the demographic coefficient reported per model
    # (country-panel Z_1_dev first, then bilateral signed / absolute dZ_1).
    coef_keys = ('Z_1_dev_coef', 'dZ_1_coef', 'dZ_1_abs_coef')

    summary_rows = []
    for section_name, results in all_results.items():
        for r in results:
            z1_key = next((k for k in coef_keys if k in r), None)
            if z1_key is None:
                continue

            p = r[z1_key.replace('_coef', '_p')]
            sample = r.get('union', r.get('test', ''))
            if r.get('step'):
                sample = f"{sample} ({r['step']})"
            summary_rows.append({
                'section': section_name,
                'union': sample,
                'dep_var': r['dep_var'],
                'Z_1_coef': r[z1_key],
                'Z_1_p': p,
                'sig': stars(p),
                'n_obs': r['n_obs'],
                'r_squared': r['r_squared'],
            })

    if summary_rows:
        headers = ['Section', 'Sample', 'Channel (DV)', 'Z_1', 'p-value', 'Sig', 'N']
        rows = []
        for s in summary_rows:
            rows.append([
                s['section'], s['union'], s['dep_var'],
                f"{s['Z_1_coef']:.3f}", f"{s['Z_1_p']:.3f}", s['sig'], s['n_obs']
            ])
        write_table(rows, headers, OUT_TABLES / "phase12_summary.md",
                    "Summary: Which Adjustment Channels Respond to Demographics Within Monetary Unions?",
                    "*Notes: Z_1 coefficients from within-union PanelGLS regressions. Positive Z_1 on CA means aging demographics → surplus; negative means aging → deficit. The question is which specific channels carry this effect.*\n*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")


# ══════════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════════

def main():
    """Run every Phase 12 section in order, then write the summary table.

    The trilemma panel is loaded once and passed to sections 1-6. Their
    result lists are collected under fixed keys, which become the "Section"
    column of the section 7 summary.
    """
    divider = "=" * 70
    print("Phase 12: Adjustment Channels Within Monetary Unions")
    print(divider)

    tri = load_data()

    # (results key, section runner) in execution order.
    pipeline = [
        ('CA_Components', section1_ca_decomposition),
        ('S_vs_I', section2_savings_investment),
        ('Bilateral', section3_bilateral_flows),
        ('Fiscal', section4_fiscal_mediation),
        ('Competitiveness', section5_competitiveness),
        ('Gross_Positions', section6_gross_positions),
    ]
    all_results = {key: runner(tri) for key, runner in pipeline}

    section7_summary(all_results)

    print("\n" + divider)
    print("Phase 12 complete.")
    print(f"Tables written to: {OUT_TABLES}")
    print(divider)


if __name__ == '__main__':
    main()
